//  Copyright (c) 2017-2018 Uber Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package broker

import (
	memCom "github.com/uber/aresdb/memstore/common"
	metaCom "github.com/uber/aresdb/metastore/common"
	"github.com/uber/aresdb/query/common"
	"github.com/uber/aresdb/query/expr"
	"github.com/uber/aresdb/utils"
	"net/http"
	"github.com/uber/aresdb/query/context"
)

// nonAggregationQueryLimit is the default limit applied to a non-aggregation
// query (one whose single measure is a number literal) when the query does
// not specify a limit itself.
const nonAggregationQueryLimit = 1000

// QueryContext is broker query context. It holds the AQL query being compiled
// together with the schema and dimension metadata that Compile collects for
// post processing.
//
// Fields populated by readSchema (during Compile):
//   - Tables: schemas indexed by table ID; 0 is the main table and 1+i is the
//     table of AQLQuery.Joins[i].
//   - TableIDByAlias: table alias (or the table name when a join has no alias)
//     to table ID.
//   - TableSchemaByName: each distinct schema, read-locked once by readSchema
//     and unlocked by releaseSchema when Compile returns.
//
// Other fields:
//   - IsNonAggregationQuery: set by processMeasures when the single measure is
//     a number literal.
//   - ReturnHLLBinary: when true, processMeasures requires the measure to be
//     the hll aggregate function.
//   - Writer: the http.ResponseWriter passed to NewQueryContext; it is not read
//     in this file (presumably used to write the response — confirm).
//   - Error: the error recorded during compilation; Compile skips its remaining
//     steps once it is set.
//   - DimRowBytes: data bytes of all dimensions plus one validity byte per
//     dimension, accumulated by sortDimensionColumns.
//   - RequestID: not set anywhere in this file.
type QueryContext struct {
	AQLQuery              *common.AQLQuery
	IsNonAggregationQuery bool
	ReturnHLLBinary       bool
	Writer                http.ResponseWriter
	Error                 error
	Tables                []*memCom.TableSchema
	TableIDByAlias        map[string]int
	TableSchemaByName     map[string]*memCom.TableSchema

	// NumDimsPerDimWidth counts dimensions per data-width bucket, widest bucket
	// first; incremented by sortDimensionColumns.
	NumDimsPerDimWidth common.DimCountsPerDimWidth
	// lookup table from enum dimension index to EnumDict, used for postprocessing.
	// Filled by processDimensions for dimensions that rewrite to an enum VarRef.
	DimensionEnumReverseDicts map[int][]string
	// DimensionVectorIndex maps a dimension's index in AQLQuery.Dimensions to its
	// position after sorting dimensions by data width (see sortDimensionColumns).
	// this should be the same as generated by datanodes. in the future we should pass
	// it down to datanodes
	DimensionVectorIndex []int
	DimRowBytes          int
	RequestID            string

	// helper used to share common codes: QCHelper performs Rewrite and filter
	// normalization for this context, which it references through QCOptions.
	QCHelper *context.QueryContextHelper
}

// NewQueryContext creates a broker QueryContext for aql. returnHLLBinary
// records whether the client requested an HLL binary response, and w is the
// response writer associated with the request. The returned context starts
// with an empty DimensionEnumReverseDicts map and a QCHelper bound to itself.
func NewQueryContext(aql *common.AQLQuery, returnHLLBinary bool, w http.ResponseWriter) *QueryContext {
	qc := &QueryContext{
		AQLQuery:                  aql,
		ReturnHLLBinary:           returnHLLBinary,
		Writer:                    w,
		DimensionEnumReverseDicts: make(map[int][]string),
	}
	// Same binding as constructing the helper inline: QCOptions points back at qc.
	qc.InitQCHelper()
	return qc
}

// GetRewrittenQuery returns a copy of qc.AQLQuery in which each expression
// string is replaced by the String() form of its parsed (and rewritten) AST.
// This covers measures, join conditions, dimensions, row filters, supporting
// measures and supporting dimensions. Entries without a parsed expression keep
// their original text. Measure-level filters are not rewritten.
//
// Every slice that is modified is cloned first, so qc.AQLQuery itself is left
// unchanged. A plain struct copy would share the slices' backing arrays, and
// the writes below would otherwise leak into the receiver's query.
func (qc *QueryContext) GetRewrittenQuery() common.AQLQuery {
	newQuery := *qc.AQLQuery

	// append(s[:0:0], s...) clones s into a fresh backing array while
	// preserving nil-ness, detaching the copy from qc.AQLQuery.
	newQuery.Measures = append(newQuery.Measures[:0:0], newQuery.Measures...)
	for i := range newQuery.Measures {
		if parsed := newQuery.Measures[i].ExprParsed; parsed != nil {
			newQuery.Measures[i].Expr = parsed.String()
		}
	}

	newQuery.Joins = append(newQuery.Joins[:0:0], newQuery.Joins...)
	for i := range newQuery.Joins {
		join := &newQuery.Joins[i]
		// Conditions is itself a slice shared with the original join.
		join.Conditions = append(join.Conditions[:0:0], join.Conditions...)
		for j := range join.Conditions {
			if j < len(join.ConditionsParsed) && join.ConditionsParsed[j] != nil {
				join.Conditions[j] = join.ConditionsParsed[j].String()
			}
		}
	}

	newQuery.Dimensions = append(newQuery.Dimensions[:0:0], newQuery.Dimensions...)
	for i := range newQuery.Dimensions {
		if parsed := newQuery.Dimensions[i].ExprParsed; parsed != nil {
			newQuery.Dimensions[i].Expr = parsed.String()
		}
	}

	newQuery.Filters = append(newQuery.Filters[:0:0], newQuery.Filters...)
	for i := range newQuery.Filters {
		if i < len(newQuery.FiltersParsed) && newQuery.FiltersParsed[i] != nil {
			newQuery.Filters[i] = newQuery.FiltersParsed[i].String()
		}
	}

	newQuery.SupportingMeasures = append(newQuery.SupportingMeasures[:0:0], newQuery.SupportingMeasures...)
	for i := range newQuery.SupportingMeasures {
		if parsed := newQuery.SupportingMeasures[i].ExprParsed; parsed != nil {
			newQuery.SupportingMeasures[i].Expr = parsed.String()
		}
	}

	newQuery.SupportingDimensions = append(newQuery.SupportingDimensions[:0:0], newQuery.SupportingDimensions...)
	for i := range newQuery.SupportingDimensions {
		if parsed := newQuery.SupportingDimensions[i].ExprParsed; parsed != nil {
			newQuery.SupportingDimensions[i].Expr = parsed.String()
		}
	}
	return newQuery
}

// Compile parses the query's expressions into ASTs, loads the referenced table
// schemas from tableSchemaReader, resolves types via Rewrite, and collects the
// metadata needed by post processing. It stops before the next step as soon as
// one step records an error in qc.Error. All schema read locks taken by
// readSchema are released before Compile returns.
func (qc *QueryContext) Compile(tableSchemaReader memCom.TableSchemaReader) {
	qc.readSchema(tableSchemaReader)
	defer qc.releaseSchema()

	// Order matters: e.g. processDimensions reads IsNonAggregationQuery, which
	// processMeasures sets.
	steps := []func(){
		qc.processJoins,
		qc.processMeasures,
		qc.processDimensions,
		qc.processFilters,
		qc.sortDimensionColumns,
	}
	for _, step := range steps {
		if qc.Error != nil {
			return
		}
		step()
	}
}

// readSchema resolves the main table and every join table through
// tableSchemaReader (held under its read lock for the duration) and populates
// qc.Tables, qc.TableIDByAlias and qc.TableSchemaByName. Each distinct schema
// is read-locked exactly once and recorded in qc.TableSchemaByName, so
// releaseSchema can unlock it even when readSchema stops early with an error.
func (qc *QueryContext) readSchema(tableSchemaReader memCom.TableSchemaReader) {
	query := qc.AQLQuery
	qc.Tables = make([]*memCom.TableSchema, len(query.Joins)+1)
	qc.TableIDByAlias = make(map[string]int)
	qc.TableSchemaByName = make(map[string]*memCom.TableSchema)

	tableSchemaReader.RLock()
	defer tableSchemaReader.RUnlock()

	// Main table: always table ID 0, addressed by its own name.
	mainSchema, err := tableSchemaReader.GetSchema(query.Table)
	if err != nil {
		qc.Error = utils.StackError(err, "unknown main table %s", query.Table)
		return
	}
	qc.TableSchemaByName[query.Table] = mainSchema
	mainSchema.RLock()
	qc.Tables[0] = mainSchema
	qc.TableIDByAlias[query.Table] = 0

	// Foreign tables: join i gets table ID i+1.
	for joinIdx, join := range query.Joins {
		tableID := joinIdx + 1

		joinSchema, err := tableSchemaReader.GetSchema(join.Table)
		if err != nil {
			qc.Error = utils.StackError(err, "unknown join table %s", join.Table)
			return
		}

		// A table referenced more than once is only locked the first time.
		if qc.TableSchemaByName[join.Table] == nil {
			qc.TableSchemaByName[join.Table] = joinSchema
			joinSchema.RLock()
		}
		qc.Tables[tableID] = joinSchema

		alias := join.Alias
		if alias == "" {
			alias = join.Table
		}
		if _, duplicate := qc.TableIDByAlias[alias]; duplicate {
			qc.Error = utils.StackError(nil, "table alias %s is redefined", alias)
			return
		}
		qc.TableIDByAlias[alias] = tableID
	}
}

// releaseSchema releases the read lock that readSchema took on every schema
// recorded in qc.TableSchemaByName.
func (qc *QueryContext) releaseSchema() {
	for name := range qc.TableSchemaByName {
		qc.TableSchemaByName[name].RUnlock()
	}
}

// processJoins parses and rewrites every join condition and stores the
// resulting ASTs in the join's ConditionsParsed. On the first parse or rewrite
// error it records qc.Error and returns; the failing join is left without its
// parsed conditions.
func (qc *QueryContext) processJoins() {
	joins := qc.AQLQuery.Joins
	for joinIdx := range joins {
		join := joins[joinIdx]
		parsedConds := make([]expr.Expr, len(join.Conditions))
		for condIdx, cond := range join.Conditions {
			parsedCond, err := expr.ParseExpr(cond)
			if err != nil {
				qc.Error = utils.StackError(err, "Failed to parse join condition: %s", cond)
				return
			}
			parsedConds[condIdx] = expr.Rewrite(qc, parsedCond)
			if qc.Error != nil {
				return
			}
		}
		join.ConditionsParsed = parsedConds
		joins[joinIdx] = join
	}
}

// processFilters parses and rewrites every row filter of the query into
// AQLQuery.FiltersParsed. It then normalizes the parsed filters through
// QCHelper.NormalizeAndFilters. The first parse or rewrite error is recorded
// in qc.Error, and the remaining filters are skipped.
func (qc *QueryContext) processFilters() {
	query := qc.AQLQuery
	parsed := make([]expr.Expr, len(query.Filters))
	query.FiltersParsed = parsed

	for idx, raw := range query.Filters {
		var err error
		if parsed[idx], err = expr.ParseExpr(raw); err != nil {
			qc.Error = utils.StackError(err, "Failed to parse filter %s", raw)
			return
		}
		if parsed[idx] = expr.Rewrite(qc, parsed[idx]); qc.Error != nil {
			return
		}
	}

	query.FiltersParsed = qc.QCHelper.NormalizeAndFilters(parsed)
}

// processMeasures parses and rewrites each measure expression and its
// measure-level filters, then normalizes the filters through
// QCHelper.NormalizeAndFilters. Afterwards it validates the query shape:
//   - exactly one measure is allowed;
//   - a number literal measure marks the query as a non-aggregation query and,
//     when AQLQuery.Limit is 0, sets it to nonAggregationQueryLimit;
//   - otherwise the measure must be an aggregate call with exactly one
//     argument, and it must be the hll call when ReturnHLLBinary is set.
//
// The first failure is recorded in qc.Error.
func (qc *QueryContext) processMeasures() {
	measures := qc.AQLQuery.Measures
	for idx := range measures {
		m := measures[idx]

		var err error
		if m.ExprParsed, err = expr.ParseExpr(m.Expr); err != nil {
			qc.Error = utils.StackError(err, "Failed to parse measure: %s", m.Expr)
			return
		}
		if m.ExprParsed = expr.Rewrite(qc, m.ExprParsed); qc.Error != nil {
			return
		}

		m.FiltersParsed = make([]expr.Expr, len(m.Filters))
		for filterIdx, rawFilter := range m.Filters {
			if m.FiltersParsed[filterIdx], err = expr.ParseExpr(rawFilter); err != nil {
				qc.Error = utils.StackError(err, "Failed to parse measure filter %s", rawFilter)
				return
			}
			if m.FiltersParsed[filterIdx] = expr.Rewrite(qc, m.FiltersParsed[filterIdx]); qc.Error != nil {
				return
			}
		}
		m.FiltersParsed = qc.QCHelper.NormalizeAndFilters(m.FiltersParsed)
		measures[idx] = m
	}

	// only one measure is supported for now
	if count := len(qc.AQLQuery.Measures); count != 1 {
		qc.Error = utils.StackError(nil, "expect one measure per query, but got %d", count)
		return
	}

	measure := qc.AQLQuery.Measures[0]
	switch parsed := measure.ExprParsed.(type) {
	case *expr.NumberLiteral:
		qc.IsNonAggregationQuery = true
		// in case user forgot to provide limit
		if qc.AQLQuery.Limit == 0 {
			qc.AQLQuery.Limit = nonAggregationQueryLimit
		}
	case *expr.Call:
		if len(parsed.Args) != 1 {
			qc.Error = utils.StackError(nil,
				"expect one parameter for aggregate function %s, but got %d",
				parsed.Name, len(parsed.Args))
			return
		}
		if qc.ReturnHLLBinary && parsed.Name != expr.HllCallName {
			qc.Error = utils.StackError(nil, "expect hll aggregate function as client specify 'Accept' as "+
				"'application/hll', but got %s",
				measure.Expr)
		}
	default:
		qc.Error = utils.StackError(nil, "expect aggregate function, but got %s", measure.Expr)
	}
}

// processDimensions parses every dimension expression. For non-aggregation
// queries it expands a wildcard dimension into one dimension per eligible
// main-table column (see getAllColumnsDimension). It then rewrites each
// resulting dimension and records the enum reverse dictionary of every
// dimension that rewrites to an enum VarRef, keyed by its final index.
//
// The first parse or rewrite error is recorded in qc.Error, and processing
// stops there. This matches the other process* steps and keeps a later rewrite
// from overwriting the original error.
func (qc *QueryContext) processDimensions() {
	rawDims := qc.AQLQuery.Dimensions
	qc.AQLQuery.Dimensions = make([]common.Dimension, 0, len(rawDims))
	// Sized by the pre-expansion count; sortDimensionColumns reallocates it for
	// the final dimension list.
	qc.DimensionVectorIndex = make([]int, len(rawDims))
	for _, dim := range rawDims {
		var err error
		dim.ExprParsed, err = expr.ParseExpr(dim.Expr)
		if err != nil {
			qc.Error = utils.StackError(err, "Failed to parse dimension: %s", dim.Expr)
			return
		}
		if _, ok := dim.ExprParsed.(*expr.Wildcard); ok && qc.IsNonAggregationQuery {
			qc.AQLQuery.Dimensions = append(qc.AQLQuery.Dimensions, qc.getAllColumnsDimension()...)
			continue
		}
		qc.AQLQuery.Dimensions = append(qc.AQLQuery.Dimensions, dim)
	}

	for idx, dim := range qc.AQLQuery.Dimensions {
		dim.ExprParsed = expr.Rewrite(qc, dim.ExprParsed)
		qc.AQLQuery.Dimensions[idx] = dim
		if qc.Error != nil {
			return
		}
		if vr, ok := dim.ExprParsed.(*expr.VarRef); ok && len(vr.EnumReverseDict) > 0 {
			qc.DimensionEnumReverseDicts[idx] = vr.EnumReverseDict
		}
	}
}

// sortDimensionColumns groups dimensions into data-width buckets, widest
// bucket first, keeping query order within a bucket. Widths come from
// common.GetDimensionDataBytes. The method:
//   - fills DimensionVectorIndex (original index -> sorted position);
//   - increments the per-bucket counts in NumDimsPerDimWidth;
//   - adds each matched dimension's data bytes, plus one validity byte per
//     dimension, to DimRowBytes.
//
// NOTE(review): a dimension whose width matches no bucket keeps position 0;
// confirm GetDimensionDataBytes only returns bucket widths.
func (qc *QueryContext) sortDimensionColumns() {
	dims := qc.AQLQuery.Dimensions
	qc.DimensionVectorIndex = make([]int, len(dims))

	position := 0
	// Bucket widths halve from 1<<(buckets-1) down to 1.
	width := 1 << uint(len(qc.NumDimsPerDimWidth)-1)
	for bucket := 0; bucket < len(qc.NumDimsPerDimWidth); bucket++ {
		for dimIdx := range dims {
			size := common.GetDimensionDataBytes(dims[dimIdx].ExprParsed)
			if size != width {
				continue
			}
			// record value offset, null offset pair;
			// null offsets will have to add total dim bytes later
			qc.DimensionVectorIndex[dimIdx] = position
			qc.NumDimsPerDimWidth[bucket]++
			qc.DimRowBytes += size
			position++
		}
		width >>= 1
	}
	// one extra byte per dimension column for validity
	qc.DimRowBytes += len(dims)
}

// getAllColumnsDimension expands a wildcard into one dimension per column of
// the main table (qc.Tables[0]), skipping deleted columns and GeoShape
// columns. Wildcards over join tables are not supported. It returns nil when
// no column qualifies.
func (qc *QueryContext) getAllColumnsDimension() (columns []common.Dimension) {
	for _, col := range qc.Tables[0].Schema.Columns {
		if col.Deleted || col.Type == metaCom.GeoShape {
			continue
		}
		columns = append(columns, common.Dimension{
			Expr:       col.Name,
			ExprParsed: &expr.VarRef{Val: col.Name},
		})
	}
	return columns
}

// Rewrite walks the expression AST and resolves data types bottom up.
// In addition it also translates enum strings and rewrites their predicates.
// The work is delegated to qc.QCHelper, which must be initialized (see
// NewQueryContext / InitQCHelper).
func (qc *QueryContext) Rewrite(expression expr.Expr) expr.Expr {
	return qc.QCHelper.Rewrite(expression)
}

// InitQCHelper binds a fresh context.QueryContextHelper to this context,
// replacing any existing one. The helper's QCOptions refers back to qc.
func (qc *QueryContext) InitQCHelper() {
	helper := new(context.QueryContextHelper)
	helper.QCOptions = qc
	qc.QCHelper = helper
}

// GetSchema returns the schema loaded by readSchema for tableID (0 is the main
// table, 1+i is the i-th join). It panics if tableID is out of range.
func (qc *QueryContext) GetSchema(tableID int) *memCom.TableSchema {
	return qc.Tables[tableID]
}

// GetTableID returns the table ID (index into Tables) registered for alias and
// whether the alias is known. An unknown alias yields (0, false).
func (qc *QueryContext) GetTableID(alias string) (int, bool) {
	if id, ok := qc.TableIDByAlias[alias]; ok {
		return id, true
	}
	return 0, false
}

// GetQuery returns the AQL query held by this context (the same pointer, not a
// copy).
func (qc *QueryContext) GetQuery() *common.AQLQuery {
	return qc.AQLQuery
}

// SetError records err as this context's error, replacing any previous value.
// Compile checks qc.Error between steps.
func (qc *QueryContext) SetError(err error) {
	qc.Error = err
}

// IsDataOnly always reports false for the broker query context.
// NOTE(review): presumably part of the options contract consumed by
// QueryContextHelper; confirm its exact meaning there.
func (qc *QueryContext) IsDataOnly() bool {
	return false
}
